[PyTorch] Fix FlashAttention 2 head_dim > 192 on sm103 and other architectures#2836
pedramr wants to merge 4 commits into NVIDIA:main
Conversation
…itectures
Replace the exact-match compute capability allowlist with a >= sm80 range check, matching flash-attn's own gate: Dao-AILab/flash-attention@bbb21d6. The allowlist ((8,0), (9,0), (10,0), (12,0)) missed sm103 (B300), sm89 (L40S), sm86 (A40), and others where FA2 supports head_dim up to 256. The sm103 case was validated on hardware with head_dim=256; the remaining architectures appear to be supported based on flash-attn's >= sm80 guarantee.
Signed-off-by: Pedram Razavi <pedram.razavi@gmail.com>
Greptile Summary: This PR fixes a bug in the FlashAttention 2 head_dim > 192 gate in get_attention_backend.
Confidence Score: 5/5. Safe to merge — minimal, targeted bug fix with correct logic and no regressions introduced. The change removes a single, clearly erroneous allowlist condition. The new condition is logically equivalent to flash-attn's own gate, given that the earlier < sm80 guard already disables FA2 before this point. The dead-code concern flagged in the previous review thread is fully resolved by removing the branch entirely. No new logic is added, and the log message is updated consistently. No files require special attention.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[get_attention_backend called] --> B{device_compute_capability < sm80?}
B -- Yes --> C[use_flash_attention_2 = False]
B -- No --> D{use_flash_attention_2 AND FA2 installed?}
D -- No --> G[Skip FA2 head_dim check]
D -- Yes --> E{head_dim_qk > 256\nOR head_dim_qk % 8 != 0?}
E -- Yes --> F[use_flash_attention_2 = False\nlog debug message]
E -- No --> H[FA2 remains enabled]
C --> I[Continue backend selection]
F --> I
G --> I
H --> I
style C fill:#f88,stroke:#c00
style F fill:#f88,stroke:#c00
style H fill:#8f8,stroke:#090
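The flowchart above can be read as straight-line gating logic. Below is a minimal Python sketch of that flow; the function name and variable names mirror the flowchart labels and are illustrative only, not the actual TransformerEngine source.

```python
# Illustrative sketch of the FA2 gating flow shown in the flowchart.
# Names (gate_flash_attention_2, fa2_installed, ...) are hypothetical.
import logging

logger = logging.getLogger("attention_backend")

def gate_flash_attention_2(use_flash_attention_2, fa2_installed,
                           device_compute_capability, head_dim_qk):
    # Pre-sm80 devices are rejected before the head_dim check is reached.
    if device_compute_capability < (8, 0):
        return False
    # Only apply the head_dim gate when FA2 is requested and installed.
    if use_flash_attention_2 and fa2_installed:
        if head_dim_qk > 256 or head_dim_qk % 8 != 0:
            logger.debug(
                "Disabling FlashAttention 2: head_dim_qk=%d "
                "(requires head_dim <= 256 and a multiple of 8 on sm80+)",
                head_dim_qk,
            )
            return False
    return use_flash_attention_2
```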
Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
/te-ci L0
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
/te-ci pytorch
for more information, see https://pre-commit.ci
Description
The head_dim > 192 gate for FlashAttention 2 in get_attention_backend used an exact-match compute capability allowlist: (8,0), (9,0), (10,0), (12,0). This excluded sm103 (B300/GB300), sm89 (L40S/RTX 4090), sm86 (A40/RTX 3090), and other valid architectures where flash-attn supports head_dim up to 256.
This PR replaces the allowlist with a >= sm80 range check, matching flash-attn's own gate: Dao-AILab/flash-attention@bbb21d6
The sm103 case was validated on hardware with head_dim=256; the remaining architectures appear to be supported based on flash-attn's >= sm80 guarantee.
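The shape of the change can be illustrated with a small sketch. The condition forms below are reconstructed from the description; the real check lives in get_attention_backend in TransformerEngine's PyTorch attention code, and the variable names here are stand-ins.

```python
# Sketch of the gate change described above (condition shapes only).
device_compute_capability = (8, 9)  # e.g. sm89 (L40S / RTX 4090)

# Old gate: exact-match allowlist. sm89, sm86, sm103, ... fall through and
# FA2 gets disabled for head_dim > 192 even though flash-attn supports
# head_dim up to 256 on those architectures.
old_ok = device_compute_capability in ((8, 0), (9, 0), (10, 0), (12, 0))

# New gate: range check matching flash-attn's own >= sm80 requirement.
new_ok = device_compute_capability >= (8, 0)

assert not old_ok and new_ok  # sm89 is now allowed for head_dim in (192, 256]
```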
Type of change
Changes
- Replace the exact-match compute capability allowlist in the head_dim > 192 gate with a range check, relying on the existing device_compute_capability < (8, 0) guard (i.e., allow head_dim up to 256 on all sm80+ architectures)
- Update the debug log message from sm80/90/100+ to sm80+
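To see whether a given machine benefits from this fix, one can inspect the local GPU's compute capability. A small check, assuming a CUDA build of PyTorch and an available GPU:

```python
# Prints whether the local device falls in the sm80+ range that the new gate
# accepts for head_dim in (192, 256].
import torch

if torch.cuda.is_available():
    cc = torch.cuda.get_device_capability()  # (major, minor), e.g. (10, 3) for sm103
    print(f"sm{cc[0]}{cc[1]}: FA2 head_dim up to 256 allowed = {cc >= (8, 0)}")
else:
    print("No CUDA device available")
```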